Running attribution!
Input: (tensor([[[[-3.8202e-02, -1.4827e-01, 1.0766e-01, ..., 1.3316e-01,
-9.8657e-02, -6.1993e-02],
[ 1.7961e-02, -7.9668e-02, -4.2784e-02, ..., 1.6037e-03,
2.6998e-02, 5.7171e-02],
[ 4.7301e-02, -2.6343e-02, -8.6350e-02, ..., 4.2694e-02,
-6.2673e-02, -7.3593e-03],
...,
[ 5.6734e-02, 1.3328e-02, 1.5833e-03, ..., -1.0371e-01,
3.4959e-02, 3.3161e-02],
[-1.3501e-01, 2.5231e-01, 1.2611e-01, ..., -2.9068e-02,
-1.3579e-01, 7.7787e-02],
[ 3.4632e-02, 2.9779e-02, -5.6488e-04, ..., 3.8829e-02,
-4.0786e-02, 1.3188e-01]],
[[-4.2960e-02, -1.4648e-01, 1.1146e-01, ..., 1.3758e-01,
-9.8202e-02, -5.6699e-02],
[ 2.0101e-02, -9.2041e-02, -4.7995e-02, ..., -7.4466e-03,
2.3998e-02, 5.9188e-02],
[ 6.2252e-02, -1.3594e-02, -8.9068e-02, ..., 5.0048e-02,
-6.6885e-02, 7.8406e-03],
...,
[ 5.8222e-02, 1.4373e-02, 5.5309e-03, ..., -1.0639e-01,
3.1540e-02, 3.6838e-02],
[-1.3131e-01, 2.4665e-01, 1.3239e-01, ..., -4.0621e-02,
-1.5014e-01, 5.5906e-02],
[ 3.5457e-02, 3.1159e-02, -6.4574e-04, ..., 3.9075e-02,
-4.1930e-02, 1.3077e-01]],
[[-4.0339e-02, -1.9827e-01, 1.2731e-01, ..., 8.5587e-02,
-1.4614e-01, -4.5109e-02],
[ 1.7036e-02, -8.4320e-02, -4.3065e-02, ..., 4.8828e-04,
2.5040e-02, 5.5996e-02],
[ 4.7582e-02, -3.9674e-02, -1.1146e-01, ..., 4.1864e-02,
-3.3406e-02, 1.5174e-02],
...,
[ 6.1540e-02, 9.5445e-03, -1.0269e-05, ..., -1.0607e-01,
3.4504e-02, 3.2050e-02],
[-1.2639e-01, 2.4143e-01, 1.3781e-01, ..., -5.2199e-02,
-1.6489e-01, 3.4250e-02],
[ 3.4365e-02, 2.8884e-02, 6.8105e-05, ..., 3.9877e-02,
-4.2655e-02, 1.2899e-01]],
...,
[[-8.2595e-02, -1.9998e-01, 1.5490e+00, ..., 3.3959e-01,
1.4255e-01, 6.5081e-01],
[ 1.1554e-02, -4.9596e-02, -3.0127e-02, ..., -4.3610e-03,
4.8181e-03, 5.1586e-02],
[ 1.4618e-01, -1.2103e-01, 6.3858e-02, ..., 1.0316e-01,
-2.4471e-01, -8.5784e-02],
...,
[ 8.2703e-02, 6.1125e-03, 1.1424e-03, ..., -1.1052e-01,
3.7004e-02, 4.8783e-02],
[-9.3220e-02, 2.7190e-01, 1.3811e-01, ..., -5.3811e-04,
-1.6357e-01, 8.1474e-02],
[ 5.6859e-02, -2.5700e-02, 3.6478e-02, ..., 6.3215e-02,
-6.2333e-02, 1.0075e-01]],
[[-3.0842e-02, -1.8030e-01, 3.5580e-01, ..., 1.7976e-01,
-1.5684e-01, 7.9073e-02],
[ 1.8380e-02, -6.1629e-02, -3.7418e-02, ..., -1.2335e-02,
1.8126e-02, 4.8978e-02],
[ 9.5952e-02, -3.6483e-02, -3.4110e-02, ..., 8.2476e-02,
-1.3885e-01, -5.9304e-03],
...,
[ 5.7276e-02, 1.3689e-02, 1.6319e-02, ..., -1.0419e-01,
4.6597e-02, 5.2112e-02],
[ 6.9066e-01, 6.7834e-01, 5.1696e-01, ..., 3.9500e-01,
-6.4991e-01, 2.7411e-01],
[ 5.0142e-02, 1.7234e-02, 1.9676e-02, ..., 5.7096e-02,
-8.4455e-02, 9.9966e-02]],
[[-1.5154e-01, -4.0557e-01, 2.0662e-01, ..., -1.0168e-01,
-2.6156e-01, 6.3641e-02],
[ 1.9388e-02, -2.3396e-01, 7.4004e-02, ..., -2.8685e-02,
-1.8736e-02, 2.0577e-02],
[ 1.4966e-01, -1.5743e-01, 9.5336e-02, ..., 1.3301e-01,
-4.4521e-01, 9.4549e-02],
...,
[ 3.3283e-01, -3.2398e-01, -2.7622e-01, ..., -8.3556e-01,
-1.4135e+00, 4.9114e-01],
[-3.9389e-02, 2.9505e-01, 1.6275e-01, ..., -3.3018e-03,
-1.9343e-01, 8.3654e-02],
[ 1.8233e-01, -2.0033e-01, 1.1411e-01, ..., 1.0744e-01,
-1.3532e-01, -1.7710e-02]]]], device='mps:0', requires_grad=True),)
Inputs layer: (tensor([[[[ 6.7257e-02, 2.1852e-02, 7.4130e-02, ..., -4.0783e-02,
1.4342e-01, 1.3830e-01],
[ 1.8222e-02, -1.1721e-02, 2.2810e-02, ..., 1.1000e-02,
-7.0729e-02, -8.8895e-03],
[-4.9888e-02, -5.0902e-03, 3.6127e-02, ..., -5.3131e-02,
1.1598e-01, -6.6142e-02],
...,
[-1.1046e-02, -1.0265e-01, 3.1398e-02, ..., 6.6149e-02,
6.3570e-02, -8.9751e-02],
[-1.8324e-01, 6.2830e-02, 4.6200e-02, ..., 1.7861e-01,
-8.4179e-02, -5.3105e-02],
[-5.5550e-02, 2.6784e-01, -1.6964e-01, ..., -2.2132e-02,
2.0114e-01, 4.2940e-02]],
[[ 6.7305e-02, 2.3125e-02, 6.8377e-02, ..., -4.0580e-02,
1.3791e-01, 1.4491e-01],
[ 9.7610e-03, 2.1347e-04, 1.3024e-02, ..., 9.4734e-03,
-7.2768e-02, -1.4875e-02],
[-4.4417e-02, -7.9859e-03, 3.4506e-02, ..., -5.8396e-02,
9.1113e-02, -6.8698e-02],
...,
[-9.1560e-03, -9.9715e-02, 2.6937e-02, ..., 7.1029e-02,
6.8079e-02, -9.4438e-02],
[-1.6535e-01, 6.1141e-02, 2.8962e-02, ..., 1.6938e-01,
-6.5112e-02, -7.1129e-02],
[-5.4615e-02, 2.6768e-01, -1.6892e-01, ..., -2.1316e-02,
2.0000e-01, 4.3236e-02]],
[[ 1.6413e-02, 2.6860e-03, 3.5159e-02, ..., -1.7897e-02,
5.8670e-02, 1.1333e-01],
[ 1.8434e-02, -1.0882e-02, 1.9270e-02, ..., 1.0286e-02,
-7.4354e-02, -1.2423e-02],
[-6.2168e-02, 1.0135e-02, 4.4744e-02, ..., -6.3592e-02,
1.1222e-01, -9.8139e-02],
...,
[-1.4830e-02, -1.0138e-01, 3.1198e-02, ..., 6.5358e-02,
7.1195e-02, -9.2705e-02],
[-1.4656e-01, 5.7309e-02, 1.2572e-02, ..., 1.6178e-01,
-4.8725e-02, -9.1210e-02],
[-5.3958e-02, 2.6912e-01, -1.6539e-01, ..., -2.2252e-02,
1.9970e-01, 4.1960e-02]],
...,
[[-6.5542e-01, 7.3925e-01, 3.5380e-01, ..., 4.7193e-01,
-1.0136e-01, 1.0693e-01],
[ 3.4425e-02, -1.6041e-02, -4.5687e-03, ..., 3.1169e-02,
-8.3258e-02, -8.0656e-03],
[-2.4455e-02, -3.1810e-02, -5.1973e-02, ..., -5.2291e-03,
3.9967e-02, -1.8897e-01],
...,
[-2.3257e-02, -9.1931e-02, 2.3208e-02, ..., 5.8166e-02,
7.9147e-02, -8.4748e-02],
[-1.6788e-01, 9.8345e-03, 3.8483e-02, ..., 1.8770e-01,
-7.3330e-02, -4.1527e-02],
[-7.0851e-02, 2.8489e-01, -1.5013e-01, ..., -2.6787e-03,
1.7639e-01, 1.6531e-02]],
[[-8.3777e-02, 1.5804e-01, 9.2494e-02, ..., 3.0265e-02,
1.4602e-01, 1.2322e-01],
[ 1.7982e-02, -9.3206e-03, -5.2507e-03, ..., 1.5176e-02,
-7.3022e-02, -1.3481e-02],
[-5.2448e-02, -1.5984e-02, 7.6266e-03, ..., -8.1312e-03,
5.1856e-02, -1.0552e-01],
...,
[-2.5957e-02, -1.0190e-01, 1.0791e-02, ..., 5.9728e-02,
1.0416e-01, -8.8054e-02],
[ 1.3211e-01, -1.0156e+00, -1.5681e-01, ..., 4.9979e-01,
1.8641e-01, 2.7841e-01],
[-8.6638e-02, 2.6933e-01, -1.7620e-01, ..., -4.6207e-03,
2.0693e-01, 1.3487e-02]],
[[-6.0754e-02, -7.4416e-02, 2.6250e-01, ..., 3.2920e-02,
-8.5164e-02, -4.1713e-02],
[ 1.4581e-02, 1.7219e-02, -3.4395e-02, ..., 1.5486e-01,
6.6413e-03, -1.0352e-01],
[ 2.3665e-03, -1.1250e-01, -2.0316e-01, ..., 1.0723e-01,
1.6910e-04, -2.5264e-01],
...,
[ 1.0065e+00, -6.3978e-01, -4.6803e-01, ..., -5.7909e-01,
-1.0093e+00, 2.4856e+00],
[-1.5921e-01, -3.6874e-02, -9.1598e-03, ..., 1.7008e-01,
-6.8307e-02, 2.2644e-02],
[-1.2472e-02, 4.0908e-01, -1.6936e-02, ..., 2.9166e-02,
4.5930e-02, -9.6726e-02]]]], device='mps:0'),)
Baseline input: (tensor([[[[-3.8202e-02, -1.4827e-01, 1.0766e-01, ..., 1.3316e-01,
-9.8657e-02, -6.1993e-02],
[ 1.7961e-02, -7.9668e-02, -4.2784e-02, ..., 1.6037e-03,
2.6998e-02, 5.7171e-02],
[ 4.7301e-02, -2.6343e-02, -8.6350e-02, ..., 4.2694e-02,
-6.2673e-02, -7.3593e-03],
...,
[ 5.6734e-02, 1.3328e-02, 1.5833e-03, ..., -1.0371e-01,
3.4959e-02, 3.3161e-02],
[-1.3501e-01, 2.5231e-01, 1.2611e-01, ..., -2.9068e-02,
-1.3579e-01, 7.7787e-02],
[ 3.4632e-02, 2.9779e-02, -5.6488e-04, ..., 3.8829e-02,
-4.0786e-02, 1.3188e-01]],
[[-4.2960e-02, -1.4648e-01, 1.1146e-01, ..., 1.3758e-01,
-9.8202e-02, -5.6699e-02],
[ 2.0101e-02, -9.2041e-02, -4.7995e-02, ..., -7.4466e-03,
2.3998e-02, 5.9188e-02],
[ 6.2252e-02, -1.3594e-02, -8.9068e-02, ..., 5.0048e-02,
-6.6885e-02, 7.8406e-03],
...,
[ 5.8222e-02, 1.4373e-02, 5.5309e-03, ..., -1.0639e-01,
3.1540e-02, 3.6838e-02],
[-1.3131e-01, 2.4665e-01, 1.3239e-01, ..., -4.0621e-02,
-1.5014e-01, 5.5906e-02],
[ 3.5457e-02, 3.1159e-02, -6.4574e-04, ..., 3.9075e-02,
-4.1930e-02, 1.3077e-01]],
[[-4.0339e-02, -1.9827e-01, 1.2731e-01, ..., 8.5587e-02,
-1.4614e-01, -4.5109e-02],
[ 1.7036e-02, -8.4320e-02, -4.3065e-02, ..., 4.8828e-04,
2.5040e-02, 5.5996e-02],
[ 4.7582e-02, -3.9674e-02, -1.1146e-01, ..., 4.1864e-02,
-3.3406e-02, 1.5174e-02],
...,
[ 6.1540e-02, 9.5445e-03, -1.0269e-05, ..., -1.0607e-01,
3.4504e-02, 3.2050e-02],
[-1.2639e-01, 2.4143e-01, 1.3781e-01, ..., -5.2199e-02,
-1.6489e-01, 3.4250e-02],
[ 3.4365e-02, 2.8884e-02, 6.8105e-05, ..., 3.9877e-02,
-4.2655e-02, 1.2899e-01]],
...,
[[-6.4666e-02, -2.2279e-01, 1.4367e+00, ..., 3.0079e-01,
9.3913e-02, 5.7781e-01],
[ 1.4685e-02, -4.4217e-02, -2.7132e-02, ..., -1.3492e-03,
1.2704e-02, 5.5928e-02],
[ 1.4704e-01, -1.1678e-01, 6.8219e-02, ..., 9.6129e-02,
-2.7377e-01, -9.1615e-02],
...,
[ 8.0699e-02, 8.0104e-03, 3.3862e-03, ..., -1.1085e-01,
3.6471e-02, 5.0242e-02],
[-9.3277e-02, 2.6532e-01, 1.3590e-01, ..., 6.0855e-04,
-1.6453e-01, 8.7835e-02],
[ 6.1304e-02, -2.9531e-02, 4.1058e-02, ..., 6.2705e-02,
-6.3086e-02, 1.0173e-01]],
[[-2.6884e-02, -1.8591e-01, 3.3385e-01, ..., 1.7805e-01,
-1.5692e-01, 6.5920e-02],
[ 1.9038e-02, -5.7073e-02, -3.5611e-02, ..., -1.0744e-02,
2.1955e-02, 5.1898e-02],
[ 9.1345e-02, -3.4546e-02, -3.6675e-02, ..., 7.9345e-02,
-1.3974e-01, -1.0724e-02],
...,
[ 5.6637e-02, 1.3481e-02, 1.8984e-02, ..., -1.0618e-01,
4.8430e-02, 5.3521e-02],
[ 7.4771e-01, 6.1334e-01, 4.9233e-01, ..., 5.3850e-01,
-6.4046e-01, 4.2198e-01],
[ 5.6981e-02, 1.5577e-02, 2.4973e-02, ..., 6.1298e-02,
-9.0802e-02, 1.0218e-01]],
[[-1.7451e-02, -3.1113e-01, 2.0889e-01, ..., -2.4565e-02,
-2.5259e-01, 5.4470e-02],
[-1.1172e-02, -2.2590e-01, -1.4908e-04, ..., -4.5345e-02,
-4.1357e-02, 5.9049e-03],
[ 9.3422e-02, -5.8440e-02, -2.7259e-02, ..., 7.9603e-02,
-1.8869e-01, 1.2196e-02],
...,
[ 1.8930e+00, -1.6141e+00, -9.7022e-01, ..., -8.6126e-01,
9.9103e-02, -6.9525e-01],
[ 6.1128e-03, 2.8894e-01, 1.4076e-01, ..., 4.6594e-02,
-2.0995e-01, 1.0698e-01],
[ 2.0692e-01, -2.3148e-01, 1.4961e-01, ..., 1.1938e-01,
-1.5921e-01, -2.1903e-02]]]], device='mps:0', requires_grad=True),)
Baselines layer: (tensor([[[[ 6.7257e-02, 2.1852e-02, 7.4130e-02, ..., -4.0783e-02,
1.4342e-01, 1.3830e-01],
[ 1.8222e-02, -1.1721e-02, 2.2810e-02, ..., 1.1000e-02,
-7.0729e-02, -8.8895e-03],
[-4.9888e-02, -5.0902e-03, 3.6127e-02, ..., -5.3131e-02,
1.1598e-01, -6.6142e-02],
...,
[-1.1046e-02, -1.0265e-01, 3.1398e-02, ..., 6.6149e-02,
6.3570e-02, -8.9751e-02],
[-1.8324e-01, 6.2830e-02, 4.6200e-02, ..., 1.7861e-01,
-8.4179e-02, -5.3105e-02],
[-5.5550e-02, 2.6784e-01, -1.6964e-01, ..., -2.2132e-02,
2.0114e-01, 4.2940e-02]],
[[ 6.7305e-02, 2.3125e-02, 6.8377e-02, ..., -4.0580e-02,
1.3791e-01, 1.4491e-01],
[ 9.7610e-03, 2.1347e-04, 1.3024e-02, ..., 9.4734e-03,
-7.2768e-02, -1.4875e-02],
[-4.4417e-02, -7.9859e-03, 3.4506e-02, ..., -5.8396e-02,
9.1113e-02, -6.8698e-02],
...,
[-9.1560e-03, -9.9715e-02, 2.6937e-02, ..., 7.1029e-02,
6.8079e-02, -9.4438e-02],
[-1.6535e-01, 6.1141e-02, 2.8962e-02, ..., 1.6938e-01,
-6.5112e-02, -7.1129e-02],
[-5.4615e-02, 2.6768e-01, -1.6892e-01, ..., -2.1316e-02,
2.0000e-01, 4.3236e-02]],
[[ 1.6413e-02, 2.6860e-03, 3.5159e-02, ..., -1.7897e-02,
5.8670e-02, 1.1333e-01],
[ 1.8434e-02, -1.0882e-02, 1.9270e-02, ..., 1.0286e-02,
-7.4354e-02, -1.2423e-02],
[-6.2168e-02, 1.0135e-02, 4.4744e-02, ..., -6.3592e-02,
1.1222e-01, -9.8139e-02],
...,
[-1.4830e-02, -1.0138e-01, 3.1198e-02, ..., 6.5358e-02,
7.1195e-02, -9.2705e-02],
[-1.4656e-01, 5.7309e-02, 1.2572e-02, ..., 1.6178e-01,
-4.8725e-02, -9.1210e-02],
[-5.3958e-02, 2.6912e-01, -1.6539e-01, ..., -2.2252e-02,
1.9970e-01, 4.1960e-02]],
...,
[[-6.2515e-01, 6.9672e-01, 3.5210e-01, ..., 4.6784e-01,
-1.0845e-01, 5.3174e-02],
[ 3.6308e-02, -1.6594e-02, 1.4097e-03, ..., 4.1001e-02,
-6.7964e-02, -6.1845e-03],
[-1.4689e-02, -4.3328e-02, -6.1983e-02, ..., -1.4398e-03,
4.2181e-02, -1.8774e-01],
...,
[-2.1084e-02, -9.3169e-02, 2.3712e-02, ..., 6.0094e-02,
7.7967e-02, -8.1557e-02],
[-1.7266e-01, 1.2970e-02, 4.1380e-02, ..., 1.9077e-01,
-7.0304e-02, -3.9218e-02],
[-7.0699e-02, 2.8755e-01, -1.4869e-01, ..., -1.1955e-03,
1.7067e-01, 1.5049e-02]],
[[-7.5672e-02, 1.5290e-01, 8.8990e-02, ..., 3.2046e-02,
1.3982e-01, 1.1454e-01],
[ 2.0688e-02, -1.0264e-02, -3.5922e-03, ..., 2.3380e-02,
-6.1244e-02, -1.3335e-02],
[-5.0905e-02, -1.6117e-02, 9.7898e-03, ..., -1.3327e-02,
5.7764e-02, -1.0200e-01],
...,
[-2.6266e-02, -9.9956e-02, 1.2263e-02, ..., 6.0819e-02,
1.0602e-01, -8.8428e-02],
[ 7.3145e-02, -1.0733e+00, -1.1518e-01, ..., 5.8111e-01,
2.5014e-01, 2.9084e-01],
[-8.6538e-02, 2.7388e-01, -1.7713e-01, ..., -3.3255e-03,
2.0242e-01, 9.7987e-03]],
[[-6.8028e-02, 2.0828e-02, 7.9210e-02, ..., 2.9539e-02,
-8.7750e-02, -2.8678e-02],
[ 3.0150e-02, 1.6722e-02, -8.9075e-02, ..., 5.4656e-02,
-1.4116e-01, -1.2706e-01],
[-2.3913e-02, -3.1963e-02, -2.1303e-02, ..., -1.7024e-02,
7.7753e-02, -1.5626e-01],
...,
[-1.6762e+00, 1.4804e-01, 2.9628e-01, ..., -6.5625e-01,
2.7657e+00, -8.3868e-01],
[-1.8645e-01, -9.2984e-02, -6.1550e-03, ..., 2.0976e-01,
-8.0872e-02, 5.3476e-02],
[-6.3047e-03, 4.6148e-01, 4.4118e-02, ..., 3.5672e-02,
-4.6170e-02, -1.3669e-01]]]], device='mps:0'),)
Attribution score for head (9, 6): tensor([-4.5160e-07])
Matches previous IG score: [ True]